# The model implementation is adapted from the dgllife library
# Copyright (C) 2021 THL A29 Limited, a Tencent company.  All rights reserved.
# The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications").
# All Tencent Modifications are Copyright (C) THL A29 Limited.

import numpy as np
import torch
import torch.nn as nn
from dgl.nn.pytorch import CFConv
from dgl.nn.pytorch.conv.cfconv import ShiftedSoftplus
from dgllife.model import MLPNodeReadout

from ..builder import BACKBONES

# Public API of this module: only the registered backbone is exported.
__all__ = [
    'SchNet',
]


class RBFExpansion(nn.Module):
    r"""Expand distances between nodes by radial basis functions.

    .. math::
        \exp(- \gamma * ||d - \mu||^2)

    where :math:`d` is the distance between two nodes and :math:`\mu` is one of
    several centers. The centers are evenly spaced over
    :math:`[\text{low}, \text{high}]`, with both endpoints included.

    The number of centers is :math:`\lceil (\text{high} - \text{low}) / \text{gap} \rceil`.
    Because both endpoints are included, the actual spacing between adjacent centers
    is :math:`(\text{high} - \text{low}) / (\text{num\_centers} - 1)`. This is close
    to ``gap`` but not exactly equal to it. Choosing fewer centers corresponds to
    reducing the resolution of the filters.

    Parameters
    ----------
    low : float
        Smallest center. Default to 0.
    high : float
        Largest center. Default to 30.
    gap : float
        Nominal difference between two adjacent centers. It determines the number of
        centers, and :math:`\gamma` is computed as its reciprocal. Default to 0.1.
    """

    def __init__(self, low=0., high=30., gap=0.1):
        super(RBFExpansion, self).__init__()

        # Remembered so that reset_parameters() can rebuild exactly the same centers.
        self._low = low
        self._high = high
        self._num_centers = int(np.ceil((high - low) / gap))
        # Fixed (non-trainable) centers. They are kept as a Parameter rather than a
        # buffer so the module's parameter list and state_dict stay unchanged.
        self.centers = nn.Parameter(self._initial_centers(), requires_grad=False)
        self.gamma = 1 / gap

    def _initial_centers(self):
        """Return the evenly spaced initial centers as a float32 CPU tensor."""
        centers = np.linspace(self._low, self._high, self._num_centers)
        return torch.tensor(centers).float()

    def reset_parameters(self):
        """Restore the centers to their initial evenly spaced values.

        The values are written in place. This preserves the existing Parameter
        object, including its device, its dtype and any external references to it.
        """
        with torch.no_grad():
            self.centers.copy_(self._initial_centers())

    def forward(self, edge_dists):
        """Expand distances.

        Parameters
        ----------
        edge_dists : float32 tensor of shape (E, 1) or (E,)
            Distances between end nodes of edges, E for the number of edges.
            A 1-D tensor is treated as shape (E, 1).

        Returns
        -------
        float32 tensor of shape (E, len(self.centers))
            Expanded distances.
        """
        if edge_dists.dim() == 1:
            # Without this, (E,) - (C,) would subtract element-wise (or fail)
            # instead of broadcasting every distance against every center.
            edge_dists = edge_dists.unsqueeze(-1)
        radial = edge_dists - self.centers
        return torch.exp(-self.gamma * (radial ** 2))


class Interaction(nn.Module):
    """Interaction block used by SchNet.

    SchNet is introduced in `SchNet: A continuous-filters convolutional neural network for
    modeling quantum interactions <https://arxiv.org/abs/1706.08566>`__.

    The block applies one continuous-filter convolution (``CFConv``), which mixes node
    features with edge features. It then applies a linear projection that keeps the
    node feature size unchanged.

    Parameters
    ----------
    node_feats : int
        Size of the node features, both on input and on output.
    edge_in_feats : int
        Size of the input edge features.
    hidden_feats : int
        Size of the hidden representations inside the convolution.
    """

    def __init__(self, node_feats, edge_in_feats, hidden_feats):
        super(Interaction, self).__init__()

        # The convolution's output size is set to node_feats (last argument). This lets
        # the output be fed to the next Interaction block unchanged.
        self.conv = CFConv(node_feats, edge_in_feats, hidden_feats, node_feats)
        self.project_out = nn.Linear(node_feats, node_feats)

    def reset_parameters(self):
        """Reinitialize model parameters."""
        # Reset order (which affects RNG consumption):
        #   1. the Linear layers inside conv.project_edge;
        #   2. conv.project_node;
        #   3. the first layer of conv.project_out;
        #   4. this block's own output projection.
        layers = [module for module in self.conv.project_edge
                  if isinstance(module, nn.Linear)]
        layers.append(self.conv.project_node)
        layers.append(self.conv.project_out[0])
        layers.append(self.project_out)
        for linear in layers:
            linear.reset_parameters()

    def forward(self, g, node_feats, edge_feats):
        """Run message passing and return the updated node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_feats)
            Input node features, V for the number of nodes.
        edge_feats : float32 tensor of shape (E, edge_in_feats)
            Input edge features, E for the number of edges.

        Returns
        -------
        float32 tensor of shape (V, node_feats)
            Updated node representations.
        """
        convolved = self.conv(g, node_feats, edge_feats)
        return self.project_out(convolved)


@BACKBONES.register_module()
class SchNet(nn.Module):
    """SchNet.

    SchNet is introduced in `SchNet: A continuous-filters convolutional neural network for
    modeling quantum interactions <https://arxiv.org/abs/1706.08566>`__.

    This class performs message passing in SchNet and returns the updated node representations.

    Parameters
    ----------
    node_feats : int
        Size for node representations to learn. Default to 64.
    hidden_feats : list of int
        ``hidden_feats[i]`` gives the size of hidden representations for the i-th interaction
        layer. ``len(hidden_feats)`` equals the number of interaction layers.
        Default to ``[64, 64, 64]``.
    num_node_types : int
        Number of node types to embed. Default to 100.
    cutoff : float
        Largest center in RBF expansion. Default to 30.
    gap : float
        Difference between two adjacent centers in RBF expansion. Default to 0.1.
    predictor_hidden_feats : int
        Hidden size of the MLP readout stored as ``self.readout``. Default to 64.
    n_tasks : int
        Output size of the MLP readout stored as ``self.readout``. Default to 1.
    """

    def __init__(self, node_feats=64, hidden_feats=None, num_node_types=100, cutoff=30., gap=0.1,
                 predictor_hidden_feats=64, n_tasks=1):
        super(SchNet, self).__init__()

        if hidden_feats is None:
            hidden_feats = [64, 64, 64]

        self.embed = nn.Embedding(num_node_types, node_feats)
        self.rbf = RBFExpansion(high=cutoff, gap=gap)

        # Every interaction layer consumes the RBF-expanded distances as edge features.
        num_rbf_centers = len(self.rbf.centers)
        self.gnn_layers = nn.ModuleList(
            Interaction(node_feats, num_rbf_centers, layer_hidden_feats)
            for layer_hidden_feats in hidden_feats)

        # NOTE(review): the readout is built here but not applied in forward();
        # presumably a downstream head uses it. TODO confirm.
        self.readout = MLPNodeReadout(node_feats, predictor_hidden_feats, n_tasks,
                                      activation=ShiftedSoftplus())

    def reset_parameters(self):
        """Reinitialize model parameters, including those of the MLP readout."""
        self.embed.reset_parameters()
        self.rbf.reset_parameters()
        for layer in self.gnn_layers:
            layer.reset_parameters()
        # MLPNodeReadout is not assumed to provide reset_parameters(), so its
        # Linear sub-layers are reset directly.
        for module in self.readout.modules():
            if isinstance(module, nn.Linear):
                module.reset_parameters()

    def forward(self, input):  # ``input`` shadows the builtin; name kept for callers.
        """Performs message passing and updates node representations.

        Parameters
        ----------
        input : tuple
            A 3-tuple ``(g, node_types, edge_dists)`` where

            * ``g`` (DGLGraph) is a batch of graphs;
            * ``node_types`` (int64 tensor of shape (V)) holds the node types to embed,
              V for the number of nodes;
            * ``edge_dists`` (float32 tensor of shape (E, 1)) holds the distances
              between end nodes of edges, E for the number of edges.

        Returns
        -------
        node_feats : float32 tensor of shape (V, node_feats)
            Updated node representations. The readout is not applied here.
        """
        g, node_types, edge_dists = input
        node_feats = self.embed(node_types)
        expanded_dists = self.rbf(edge_dists)
        for gnn in self.gnn_layers:
            node_feats = gnn(g, node_feats, expanded_dists)
        return node_feats
